# import packages and modules
import pandas as pd
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
import numpy as np
import sklearn as skl
import os
import sys
import pickle
from joblib import dump, load
import yaml
import statsmodels.api as sm 
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LinearRegression as reg
import pdb
from pprint import pprint
from statistics import mean
import random
random.seed(50)
import math
import argparse
from tqdm import tqdm
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score

# Default input directory: the default `indir` of get_feature_importance,
# which loads <modelname>.joblib from it.
defaultaxpayer_id = "/REDACTED/fairness/code/rf/data/"
# Default output directory: get_feature_importance writes its plot and CSV here.
defaultout = "/REDACTED/"

## linear estimator
def regEstimate(dataset, pbvarb, outcome, wvarb):
    """Weighted least-squares coefficient of `outcome` on `pbvarb`.

    Fits the formula ``outcome ~ pbvarb`` on `dataset` by WLS, using the
    `wvarb` column as observation weights and HC1 robust standard errors.

    Returns:
        tuple: (coefficient on `pbvarb`, its HC1 standard error).
    """
    formula = outcome + '~' + pbvarb
    fit = smf.wls(formula, dataset, weights=dataset[wvarb]).fit(cov_type='HC1')
    return fit.params[pbvarb], fit.bse[pbvarb]

## probabilistic estimator
def chenEstimate(dataset, pbvarb, outcome, wvarb):
    """Probability-weighted difference in outcome rates between two groups.

    Column `pbvarb` holds each row's probability of group membership (in this
    script, 'predicted_prob_black'). The group's rate weights every row by
    ``pbvarb * wvarb``; the complement's rate weights it by
    ``(1 - pbvarb) * wvarb``.

    Returns:
        The group's weighted mean of `outcome` minus the complement's.
    """
    p_group = dataset[pbvarb]
    y = dataset[outcome]
    w = dataset[wvarb]
    p_other = 1 - p_group
    group_rate = (p_group * y * w).sum() / (p_group * w).sum()
    other_rate = (p_other * y * w).sum() / (p_other * w).sum()
    return group_rate - other_rate

# functions for determining feature importances
def get_feature_names(configpath='/REDACTED/fairness/code/rf/config/',
                      datapath='/REDACTED/fairness/code/rf/data/',
                      dep_database=False):
    """Return the configured feature names that exist as columns of the RF dataset.

    Reads ``data-config.yaml`` from `configpath`. Keeps, in config order, the
    entries of the selected feature list that are columns of the matching CSV
    in `datapath`.

    Args:
        configpath: Directory (with trailing slash) holding ``data-config.yaml``.
        datapath: Directory (with trailing slash) holding the cleaned CSVs.
        dep_database: If truthy, use ``clean_rf_data_plus_dep_database.csv``
            and the ``features_plus_dep_database_str`` config entry.
            Otherwise use ``clean_rf_data.csv`` and ``features_str``.

    Returns:
        pandas.Index of the retained feature names.
    """
    # The config is plain data, so use safe_load. A bare yaml.load(stream)
    # raises TypeError on PyYAML >= 6 because it requires an explicit Loader.
    # `with` guarantees the file handle is closed.
    with open(configpath + 'data-config.yaml', 'r') as stream:
        config = yaml.safe_load(stream)

    if dep_database:
        csv_name = 'clean_rf_data_plus_dep_database.csv'
        config_key = 'features_plus_dep_database_str'
    else:
        csv_name = 'clean_rf_data.csv'
        config_key = 'features_str'

    # Only the header row is needed to see which features exist,
    # so skip loading the data rows.
    header = pd.read_csv(datapath + csv_name, nrows=0)
    feature_vars = [x for x in config[config_key] if x in header.columns]
    return header[feature_vars].columns

def get_feature_importance(indir=defaultaxpayer_id,
                            outdir=defaultout,
                            modelname='EITC_NCMP_RF_Class_100_plus_dep_database',
                            feature_names=None):
    """Load a saved model and export its feature importances.

    Writes two files to `outdir`:
    - ``<modelname>_feat_imp.png``: horizontal bar chart of the 20 largest
      importances.
    - ``<modelname>_importances.csv``: all importances.

    Args:
        indir: Directory (with trailing slash) containing ``<modelname>.joblib``.
        outdir: Directory (with trailing slash) for the PNG and CSV outputs.
        modelname: Base file name of the saved model, without extension.
        feature_names: Labels for ``model.feature_importances_``. If None, a
            positional index is used.

    Returns:
        pandas.Series of ``model.feature_importances_`` indexed by
        `feature_names`.
    """
    # Load from `indir`; the parameter was previously ignored in favour of
    # the global default directory.
    model = load(indir + modelname + '.joblib')
    importances = pd.Series(model.feature_importances_, index=feature_names)
    # Draw on a fresh figure. Without an explicit `ax`, Series.plot reuses the
    # current axes, so repeated calls would stack every model's bars into one
    # chart. Close the figure afterwards so figures do not accumulate.
    fig, ax = plt.subplots()
    importances.nlargest(20).plot(kind='barh', title=modelname, ax=ax)
    fig.savefig(outdir + modelname + '_feat_imp.png', bbox_inches='tight')
    plt.close(fig)
    importances.to_csv(outdir + modelname + '_importances.csv')
    return importances

# Read in the main RF dataset and keep only the EITC returns, i.e. those
# with activity code 270 or 271.
full_research_audits_df = pd.read_csv('/REDACTED/fairness/code/rf/data/clean_rf_data_plus_dep_database.csv')
eitc_activity_codes = [270, 271]
full_research_audits_df = full_research_audits_df[full_research_audits_df['activity_code'].isin(eitc_activity_codes)]

# Define the noncompliance indicator: 1 for taxpayers who underreported
# their tax by more than $100, else 0.
full_research_audits_df['did_noncomp'] = np.where(full_research_audits_df['chg_in_tax_owed_pv'] > 100, 1, 0)


def _weighted_percentile(values, weights, percentile):
    """Return the weighted `percentile` (0-100) of `values`.

    Sorts the values and accumulates their weights. The result linearly
    interpolates between the two sorted values whose cumulative weights
    bracket ``total_weight * percentile / 100``. If the target falls before
    the first cumulative weight, the smallest value is returned; if it falls
    at or past the last, the largest value is returned. (The original inline
    code left the result undefined in those cases.)

    Raises:
        ValueError: if `values` is empty.
    """
    values = np.asarray(values)
    weights = np.asarray(weights)
    if values.size == 0:
        raise ValueError('cannot take a weighted percentile of an empty array')
    order = np.argsort(values)
    sorted_values = values[order]
    sorted_weights = weights[order]
    cumulative_weights = np.cumsum(sorted_weights)
    target_weight = np.sum(sorted_weights) * percentile / 100
    # First position whose cumulative weight exceeds the target, so that
    # cumulative_weights[index - 1] <= target_weight < cumulative_weights[index].
    index = np.searchsorted(cumulative_weights, target_weight, side='right')
    if index == 0:
        # The smallest value alone carries more than the target weight.
        return sorted_values[0]
    if index >= len(sorted_values):
        # Target at/after the total weight (percentile >= 100 or rounding).
        return sorted_values[-1]
    lower = cumulative_weights[index - 1]
    upper = cumulative_weights[index]
    alpha = (target_weight - lower) / (upper - lower)
    return sorted_values[index - 1] + alpha * (sorted_values[index] - sorted_values[index - 1])


# Define did_very_noncomp: 1 for the top 1.45% of underreporting by weight,
# i.e. a change in tax owed above its weighted 98.55th percentile (100 - 1.45).
percentile_98_55 = _weighted_percentile(full_research_audits_df['chg_in_tax_owed_pv'],
                                        full_research_audits_df['base_weight'],
                                        98.55)
full_research_audits_df['did_very_noncomp'] = np.where(full_research_audits_df.chg_in_tax_owed_pv > percentile_98_55, 1, 0)

# Sanity check: the weighted share of returns flagged as very noncompliant
# should be approximately 1.45%.
is_very_noncomp = full_research_audits_df.did_very_noncomp == 1
print(full_research_audits_df.loc[is_very_noncomp, 'base_weight'].sum() / full_research_audits_df.base_weight.sum())


# Export feature importances (a plot and a CSV each) for the total
# underreporting prediction models trained on training sets 0-4.
# These are the first step toward picking the 40 most important features.
feature_names_dep_database = get_feature_names(dep_database=True)
for train_set in range(5):
    get_feature_importance(
        modelname='EITC_NCMP_RF_Reg_plus_dep_database_train_set_' + str(train_set),
        feature_names=feature_names_dep_database,
    )

# Read back the five per-training-set importance CSVs. In each file, the
# feature names are in column 'Unnamed: 0' and the importances in column '0'.
# Column '0' of sets 1-4 is renamed to the set number before the left-join
# onto set 0.
folds = None
for fold_idx in range(5):
    fold_imp = pd.read_csv('/REDACTED/EITC_NCMP_RF_Reg_plus_dep_database_train_set_' + str(fold_idx) + '_importances.csv')
    if folds is None:
        folds = fold_imp
    else:
        folds = folds.merge(fold_imp.rename(columns={'0': str(fold_idx)}), how='left', on='Unnamed: 0')

# Average the importances over the five sets (summed left to right, then
# divided by 5).
importance_sum = folds['0']
for fold_idx in range(1, 5):
    importance_sum = importance_sum + folds[str(fold_idx)]
folds['importance_avg'] = importance_sum / 5

folds = folds.rename(columns={'Unnamed: 0': 'feature'})
importances = pd.Series(folds['importance_avg'].tolist(), index=folds['feature'])

# Keep the 40 features with the highest average importance, in alphabetical
# order. Then move these two features to the end, in this order. The
# descriptions list is later attached to the results positionally, so it
# must follow the same order.
top_40 = sorted(importances.nlargest(40).index)
for moved_feature in ('count_issues', 'irs_dep_risk_score'):
    top_40.remove(moved_feature)
    top_40.append(moved_feature)

# Human-readable descriptions of the top_40 features: one per feature, in the
# same order as top_40, because they are later attached positionally as
# results['feature_description'].
# NOTE(review): the list contents are redacted in this copy. This line is a
# placeholder and is not valid Python as written.
top_40_descriptions = [REDACTED. AVAILABLE UPON REQUEST OF AUTHORS AND WITH APPROPRIATE CLEARANCE FROM IRS.]

                     
# Exclude activity_code and filing_status from the per-feature analysis, and
# drop their descriptions so the two lists stay aligned. Presumably
# activity_code is excluded because the sample is already restricted to
# activity codes 270/271 (TODO confirm).
# list.remove raises ValueError if an entry is missing. Afterwards, top_40
# holds two fewer entries than its name suggests.
top_40.remove('activity_code')
top_40.remove('filing_status')

top_40_descriptions.remove('Activity Code')
top_40_descriptions.remove('Filing Status Code')


# Attach each taxpayer's fold number, which defines the train/test split.
# This is a left join:
# - taxpayers absent from fold_info.csv get a missing fold;
# - a taxpayer listed more than once there is duplicated here.
folds = pd.read_csv('/REDACTED/fold_info.csv').loc[:, ['taxpayer_id_new', 'fold']]
full_research_audits_df = full_research_audits_df.merge(
    folds,
    how='left',
    on='taxpayer_id_new',
)


# Loop over the five folds, setting each one aside for testing and using the
# other four for training. Within each fold, train a shallow decision-tree
# classifier on each feature in col_varbs alone to predict noncompliance.
# For each of these classifiers, re-shuffle the test set 1,000 times. Each
# time, select a group of taxpayers for audit under a fixed weighted budget
# and record disparity and performance metrics. Each metric is averaged over
# the shuffles.

# Metric names. Every metric gets one list per fold (suffix 0-4) holding one
# entry per feature. The insertion order of `results` fixes the column order
# of the DataFrame built from it.
_METRICS = ['feature', 'hit_rate', 'correlation', 'disparity', 'precision',
            'mean_noncomp_audited', 'prob_audit_given_noncomp',
            'prob_audit_given_very_noncomp', 'audited_wgt']
results = {metric + str(k): [] for k in range(5) for metric in _METRICS}
output = {}
col_varbs = top_40


def _select_for_audit(scored, budget, seed):
    """Shuffle `scored`, rank it by 'pred_prob' and flag rows for audit.

    Rows are flagged in descending 'pred_prob' order while their cumulative
    'base_weight' stays below `budget`. Any budget left over is spent on a
    fraction of the next row: that row's weight is split between an appended
    audited copy (carrying the leftover weight) and the unaudited original.
    The seeded shuffle varies the pre-sort order of rows, which can change
    how tied scores end up ranked.

    Returns:
        A new DataFrame with added 'cumsum' and 'audit' columns.
    """
    ranked = scored.sample(frac=1, random_state=seed)
    ranked = ranked.sort_values('pred_prob', ascending=False)
    ranked['cumsum'] = ranked.base_weight.cumsum()
    ranked['audit'] = ranked['cumsum'] < budget
    actual = ranked.loc[ranked.audit == 1, 'base_weight'].sum()
    if actual / budget < 1:
        row = ranked.index[(ranked.audit.values != 1).argmax()]
        dif = budget - actual
        dupe = ranked.loc[row].copy()
        dupe['audit'] = True
        dupe['base_weight'] = dif
        ranked.loc[row, 'base_weight'] = ranked.loc[row, 'base_weight'] - dif
        # DataFrame.append was removed in pandas 2.0. This concat reproduces
        # what append did for a single named Series row.
        ranked = pd.concat([ranked, dupe.to_frame().T.infer_objects()])
    return ranked


for fold_num in [0, 1, 2, 3, 4]:
    print(fold_num)
    full_research_audits_df_fold = full_research_audits_df[full_research_audits_df.fold == fold_num]
    # NOTE: rows with a missing fold satisfy `fold != fold_num` for every
    # fold, so they are always in the training split and never tested.
    full_research_audits_df_other_folds = full_research_audits_df[full_research_audits_df.fold != fold_num]
    model_cols = col_varbs + ['predicted_prob_black', 'base_weight', 'chg_in_tax_owed_pv', 'did_very_noncomp']
    train_x = full_research_audits_df_other_folds[model_cols]
    train_y = full_research_audits_df_other_folds['did_noncomp']
    test_x = full_research_audits_df_fold[model_cols]
    test_y = full_research_audits_df_fold['did_noncomp']
    total_weight = test_x.base_weight.sum()
    # Audit budget: 1.45% of the test fold's total weight.
    budget = 0.0145 * total_weight
    suffix = str(fold_num)
    for col in col_varbs:
        # Depth-1 trees for these two features, depth-2 trees for all others.
        max_depth = 1 if col in ['ex_chld', 'count_issues'] else 2
        clf = tree.DecisionTreeClassifier(max_depth=max_depth)
        clf.fit(np.array(train_x[[col]]).reshape(-1, 1), train_y)
        output[col] = clf.predict_proba(np.array(test_x[col]).reshape(-1, 1))
        test_x_copy = test_x.copy()
        test_x_copy['pred_prob'] = output[col][:, 1]
        test_x_copy['outcome'] = test_y
        # The score/outcome correlation does not depend on the shuffle, so
        # compute it once rather than 1,000 times.
        correlation = np.corrcoef(output[col][:, 1], test_y)[0][1]
        draws = {metric: [] for metric in _METRICS if metric != 'feature'}
        for j in tqdm(range(1000), desc=col, leave=False):
            audited = _select_for_audit(test_x_copy, budget, j)
            selected = audited[audited.audit == True]
            draws['disparity'].append(chenEstimate(audited, 'predicted_prob_black', 'audit', 'base_weight'))
            # hit_rate, precision and the prob_audit_* rates are unweighted
            # means over rows; the split row appears twice.
            draws['hit_rate'].append(selected.outcome.mean())
            draws['correlation'].append(correlation)
            draws['precision'].append(precision_score(audited['outcome'], audited['audit']))
            draws['mean_noncomp_audited'].append(selected.chg_in_tax_owed_pv.mean())
            draws['prob_audit_given_noncomp'].append(audited[audited.outcome == 1].audit.mean())
            draws['prob_audit_given_very_noncomp'].append(audited[audited.did_very_noncomp == 1].audit.mean())
            draws['audited_wgt'].append(selected.base_weight.sum() / audited.base_weight.sum())
        results['feature' + suffix].append(col)
        for metric, values in draws.items():
            results[metric + suffix].append(np.mean(values))

# Put the results into a DataFrame with one row per feature. Every fold
# processed col_varbs in the same order, so a single feature column is kept.
results = pd.DataFrame(results)
results = results.drop(columns=['feature1', 'feature2', 'feature3', 'feature4'])
results = results.rename(columns={'feature0': 'feature'})

# Average every collected metric over the 5 folds (fold 0 + ... + fold 4,
# summed left to right, then divided by 5). audited_wgt was collected per
# fold but previously never averaged.
for metric in ['disparity', 'hit_rate', 'correlation', 'precision',
               'mean_noncomp_audited', 'prob_audit_given_noncomp',
               'prob_audit_given_very_noncomp', 'audited_wgt']:
    fold_cols = [results[metric + str(k)] for k in range(5)]
    results[metric] = sum(fold_cols[1:], fold_cols[0]) / 5
# Positional assignment: descriptions must be in the same order as the features.
results['feature_description'] = top_40_descriptions

# write out results to a csv, optional to read in
#results.to_csv('/REDACTED/hadi_results_table_folds_V9.csv', index=False)
#results = pd.read_csv('/REDACTED/hadi_results_table_folds_V9.csv')

# old versions
#results_v4 = pd.read_csv('/REDACTED/hadi_results_table_folds_V4.csv')
#results_v8 = pd.read_csv('/REDACTED/hadi_results_table_folds_V8.csv')


def _save_disparity_scatter(frame, x_col, xlabel, ylabel, path):
    """Scatter frame[x_col] against frame['disparity_pp'] and save it to `path`.

    Draws one point per row on a 10x8-inch figure with 16-pt axis labels
    and tick labels. The figure is closed after saving.
    """
    fig, _ = plt.subplots(figsize=(10, 8))
    plt.scatter(frame[x_col], frame['disparity_pp'])
    plt.xlabel(xlabel, fontsize=16)
    plt.ylabel(ylabel, fontsize=16)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.savefig(path)
    plt.close(fig)


# Disparity expressed in percentage points.
results['disparity_pp'] = results['disparity'] * 100

### plot disparity by high-noncomplier discovery rate
_save_disparity_scatter(results, 'prob_audit_given_very_noncomp',
                        'High-Noncomplier Discovery Rate',
                        'Disparity of induced prediction at budget',
                        '/REDACTED/disparity_by_high_noncomp_discovery.png')

### plot disparity by mean detected underreporting
# NOTE(review): the output file name keeps its original (garbled) spelling so
# existing references to it still resolve; only the axis label is corrected.
_save_disparity_scatter(results, 'mean_noncomp_audited',
                        'Mean Detected Underreporting ($)',
                        'Disparity (probabilistic, percentage points)',
                        '/REDACTED/disparity_by_mean_detected_underreportaxpayer_idg.png')